dataloader.py
import numpy as np
import torch
import torch.utils.data as utils
class StandardScaler:
    """Standardize data with a fixed mean and standard deviation."""

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def transform(self, data):
        # Z-score the data using the stored statistics.
        return (data - self.mean) / self.std

    def inverse_transform(self, data):
        # Undo the z-scoring, recovering the original scale.
        return (data * self.std) + self.mean
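
# A minimal round-trip check for StandardScaler (illustrative only, not part
# of the original module): transform followed by inverse_transform should
# recover the input exactly.
#
#   _s = StandardScaler(mean=0.5, std=2.0)
#   _x = np.array([1.0, 2.0, 3.0])
#   assert np.allclose(_s.inverse_transform(_s.transform(_x)), _x)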

def init_dataloader(dataset_config):
    # Note: "time_seires" (config key) and "timeseires" (dict key) keep their
    # original spellings so they match the config file and the saved .npy
    # dictionary, respectively.
    data = np.load(dataset_config["time_seires"], allow_pickle=True).item()
    final_fc = data["timeseires"]   # (subjects, nodes, timepoints)
    final_pearson = data["corr"]    # (subjects, nodes, nodes) correlation matrices
    labels = data["label"]

    _, _, timeseries = final_fc.shape
    _, node_size, node_feature_size = final_pearson.shape

    # Z-score the time series using global mean and standard deviation.
    scaler = StandardScaler(mean=np.mean(final_fc), std=np.std(final_fc))
    final_fc = scaler.transform(final_fc)

    # Build one identity matrix per subject as the pseudo (one-hot node) input.
    pseudo = []
    for i in range(len(final_fc)):
        pseudo.append(np.diag(np.ones(final_pearson.shape[1])))

    # Reshape the stacked identity matrices to the node count of the atlas
    # (cc200 has 200 regions, except the ADHD release, which has 190).
    if 'cc200' in dataset_config['atlas']:
        if dataset_config["dataset"] == "ADHD":
            pseudo_arr = np.concatenate(pseudo, axis=0).reshape((-1, 190, 190))
        else:
            pseudo_arr = np.concatenate(pseudo, axis=0).reshape((-1, 200, 200))
    elif 'aal' in dataset_config['atlas']:
        pseudo_arr = np.concatenate(pseudo, axis=0).reshape((-1, 116, 116))
    elif 'cc400' in dataset_config['atlas']:
        pseudo_arr = np.concatenate(pseudo, axis=0).reshape((-1, 392, 392))
    else:
        pseudo_arr = np.concatenate(pseudo, axis=0).reshape((-1, 111, 111))

    final_fc, final_pearson, labels, pseudo_arr = [
        torch.from_numpy(d).float()
        for d in (final_fc, final_pearson, labels, pseudo_arr)
    ]

    # Split subjects into train/val/test according to the configured ratios.
    length = final_fc.shape[0]
    train_length = int(length * dataset_config["train_set"])
    val_length = int(length * dataset_config["val_set"])

    dataset = utils.TensorDataset(final_fc, final_pearson, labels, pseudo_arr)
    train_dataset, val_dataset, test_dataset = utils.random_split(
        dataset, [train_length, val_length, length - train_length - val_length])

    train_dataloader = utils.DataLoader(
        train_dataset, batch_size=dataset_config["batch_size"],
        shuffle=True, drop_last=False)
    val_dataloader = utils.DataLoader(
        val_dataset, batch_size=dataset_config["batch_size"],
        shuffle=True, drop_last=False)
    test_dataloader = utils.DataLoader(
        test_dataset, batch_size=dataset_config["batch_size"],
        shuffle=True, drop_last=False)

    return (train_dataloader, val_dataloader, test_dataloader), node_size, node_feature_size, timeseries
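
# Example usage (a sketch: the config keys mirror those read above, but the
# .npy path, atlas/dataset names, and ratio values are assumptions, not part
# of the original file).
if __name__ == "__main__":
    example_config = {
        "time_seires": "data/abide_cc200.npy",  # hypothetical path to the saved dict
        "atlas": "cc200",
        "dataset": "ABIDE",
        "train_set": 0.7,
        "val_set": 0.1,
        "batch_size": 16,
    }
    (train_loader, val_loader, test_loader), n_nodes, n_feats, n_timepoints = \
        init_dataloader(example_config)
    print(n_nodes, n_feats, n_timepoints)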