Questions about data loading and training #5
Thank you for your interest in this project. Additionally, I discovered an error in the hyperparameters listed in my paper, thanks to your comment; without your input, I might not have noticed the mistake. I also recommend using a train_base_learning_rate of 1e-5. If you have any further questions, or if training still does not work as expected, please feel free to contact me by email at dlcjfgmlnasa28@korea.ac.kr along with the training logs. Thank you again!

```python
# -*- coding:utf-8 -*-
import os
import re
import abc
import mne
import numpy as np
from typing import Dict, Tuple


class Base(object):
    def __init__(self, labels: Dict, sfreq: float):
        self.labels = labels
        self.sfreq = sfreq

    @abc.abstractmethod
    def parser(self, path) -> Tuple[np.ndarray, np.ndarray]:
        pass

    def save(self, src_path: str, trg_path: str):
        x, y = self.parser(src_path)
        np.savez(trg_path, x=x, y=y, sfreq=self.sfreq)


class SHHS(Base):
    def __init__(self, labels: Dict, sfreq: float):
        super().__init__(labels=labels, sfreq=sfreq)

    def parser(self, path) -> Tuple[np.ndarray, np.ndarray]:
        edf_data = mne.io.read_raw_edf(path, preload=True)
        idx = edf_data.ch_names.index('EEG')
        data = edf_data.get_data()[idx]
        # Split the continuous signal into 30-second epochs
        # (cast to int in case sfreq is given as a float).
        x = np.reshape(data, [-1, int(30 * self.sfreq)])
        # Derive the matching NSRR annotation file from the EDF name.
        name_ = os.path.basename(path).split('.')[0] + '-nsrr.xml'
        label_path = os.path.join(*path.split('/')[:-3], 'annotations-events-nsrr', 'shhs1', name_)
        y = self.read_annotation_regex(label_path)
        y = np.array([self.labels[str(y_)] for y_ in y])
        return x, y

    @staticmethod
    def read_annotation_regex(filename):
        with open(filename, 'r') as f:
            content = f.read()
        # Match the sleep-stage event blocks in the NSRR XML annotations.
        patterns_stages = re.findall(
            r'<EventType>Stages.Stages</EventType>\n' +
            r'<EventConcept>.+</EventConcept>\n' +
            r'<Start>[0-9\.]+</Start>\n' +
            r'<Duration>[0-9\.]+</Duration>',
            content)
        stages, starts, durations = [], [], []
        for pattern in patterns_stages:
            lines = pattern.splitlines()
            stage_line = lines[1]
            stage = int(stage_line[-16])
            start_line = lines[2]
            start = float(start_line[7:-8])
            duration_line = lines[3]
            duration = float(duration_line[10:-11])
            # Each annotated segment must be a whole number of 30-second epochs.
            assert duration % 30 == 0.
            epochs_duration = int(duration) // 30
            stages += [stage] * epochs_duration
            starts += [start]
            durations += [duration]
        return stages


if __name__ == '__main__':
    src_base_path_ = os.path.join('..', '..', '..', '..', 'Dataset', 'SHHS', 'polysomnography', 'edfs', 'shhs1')
    trg_base_path_ = os.path.join('..', 'data', 'stage', 'SHHS')
    # 0 - Wake | 1 - Stage 1 | 2 - Stage 2 | 3 - Stage 3/4 | 4 - Stage 3/4 | 5 - REM | 9 - Movement/Wake
    dataset = SHHS(labels={'0': 0, '1': 1, '2': 2, '3': 3, '4': 3, '5': 4, '9': 0}, sfreq=125)
    for name in open(os.path.join('.', 'selected_shhs1_files.txt')).readlines():
        name = name.strip() + '.edf'
        src_path_ = os.path.join(src_base_path_, name)
        trg_path_ = os.path.join(trg_base_path_, name.split('.')[0] + '.npz')
        dataset.save(src_path_, trg_path_)
```

If you found this helpful, please consider giving it a star! ⭐
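As a quick sanity check on the preprocessing output (a hypothetical helper, not part of the original reply), each saved `.npz` file can be reloaded to confirm that every 30-second epoch has exactly one stage label:

```python
import numpy as np


def check_epochs(npz_path):
    """Reload a file written by the preprocessing script and verify
    epoch/label alignment. Hypothetical helper for illustration."""
    npz = np.load(npz_path)
    x, y, sfreq = npz['x'], npz['y'], float(npz['sfreq'])
    # Each row of x should be one 30-second epoch at the stored sampling rate.
    assert x.shape[1] == int(30 * sfreq), 'epoch length mismatch'
    # There should be exactly one stage label per epoch.
    assert x.shape[0] == y.shape[0], 'epoch/label count mismatch'
    return x.shape, y.shape
```

If either assertion fires on a real file, the signal length and annotation duration disagree, which is a common cause of poor training results.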
Thank you for providing the code; it has been very helpful in understanding the paper.
I am trying to reproduce the work on the SHHS dataset based on the hyperparameters provided in the paper, but it seems that I am encountering some issues.
Following the selected_shhs1_files.txt, I read the EEG signals from the corresponding EDF files and the labels from the XML files, modifying the labels by changing 4 to 3 and 5 to 4. These are the only data processing steps I have made.
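The remapping described above (Stage 4 merged into 3, REM moved to 4) can be sketched as a small lookup table; the `9 -> 0` entry follows the maintainer's script, and the names here are illustrative only:

```python
# NSRR stage codes -> 5-class training labels, as described above:
# 0 Wake, 1 N1, 2 N2, 3 and 4 merged into N3, 5 REM -> 4, 9 Movement/Wake -> 0.
STAGE_MAP = {0: 0, 1: 1, 2: 2, 3: 3, 4: 3, 5: 4, 9: 0}


def remap_stages(stages):
    """Map raw NSRR stage codes to the 5-class labels used for training."""
    return [STAGE_MAP[s] for s in stages]
```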
Next, I proceeded with pretraining using the following command:
The only uncertain part is the `batch_size`, which I reduced slightly (I'm sorry, I cannot find the previous record), but I kept the `base_learning_rate` at 2e-5, since I saw that the code includes logic to adjust the learning rate. Next is the fine-tuning stage with the following commands:
I found that the default parameters match those described in the paper, so I only modified the batch_size.
Currently, the results I’m achieving from the replication are not very good. I’m wondering if there are any misunderstandings in the data processing or training stages. I look forward to your feedback.
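The "logic to adjust the learning rate" mentioned above is not shown in this thread. One common convention (an assumption for illustration, not the repository's confirmed behavior) is the linear scaling rule, where the effective learning rate grows in proportion to the batch size:

```python
def scaled_lr(base_lr, batch_size, reference_batch_size=256):
    """Linear scaling rule: scale the base learning rate by the ratio of
    the actual batch size to a reference batch size. The reference value
    of 256 is a hypothetical default; the repository may use another."""
    return base_lr * batch_size / reference_batch_size
```

Under this rule, halving the batch size also halves the effective learning rate, which is why reducing `batch_size` while keeping `base_learning_rate` fixed can still change the training dynamics.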