forked from athon2/BraTS2018_NvNet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
data_preprocess.py
126 lines (110 loc) · 6.14 KB
/
data_preprocess.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import os
import glob
import numpy as np
import tables
from random import shuffle
from utils.normalize import normalize_data_storage, reslice_image_set
from utils import pickle_dump, pickle_load
from main import config
def create_data_file(out_file, n_channels, n_samples, image_shape):
    """
    Create a new HDF5 file with empty, enlargeable arrays for image data.

    Three blosc-compressed arrays are created under the file root, each
    growing along its first axis:
      * 'data'   -- float32, shape (0, n_channels, *image_shape)
      * 'truth'  -- uint8,   shape (0, 1, *image_shape)
      * 'affine' -- float32, shape (0, 4, 4)

    :param out_file: Path of the HDF5 file to create (opened with mode 'w').
    :param n_channels: Size of the second axis of the 'data' array.
    :param n_samples: Expected number of rows, passed to PyTables as a sizing hint.
    :param image_shape: Iterable giving the spatial shape of one image.
    :return: Tuple (hdf5_file, data_storage, truth_storage, affine_storage). The
        caller is responsible for closing hdf5_file.
    """
    hdf5_file = tables.open_file(out_file, mode='w')
    try:
        filters = tables.Filters(complevel=5, complib='blosc')
        data_shape = tuple([0, n_channels] + list(image_shape))
        truth_shape = tuple([0, 1] + list(image_shape))
        data_storage = hdf5_file.create_earray(hdf5_file.root, 'data', tables.Float32Atom(), shape=data_shape,
                                               filters=filters, expectedrows=n_samples)
        truth_storage = hdf5_file.create_earray(hdf5_file.root, 'truth', tables.UInt8Atom(), shape=truth_shape,
                                                filters=filters, expectedrows=n_samples)
        affine_storage = hdf5_file.create_earray(hdf5_file.root, 'affine', tables.Float32Atom(), shape=(0, 4, 4),
                                                 filters=filters, expectedrows=n_samples)
    except Exception:
        # The handle never reaches the caller on failure, so close it here
        # rather than leaking it (an open handle can also block file removal).
        hdf5_file.close()
        raise
    return hdf5_file, data_storage, truth_storage, affine_storage
def write_image_data_to_file(image_files, data_storage, truth_storage, image_shape, n_channels, affine_storage,
                             truth_dtype=np.uint8, crop=True, label_indices=None, save_truth=True):
    """
    Reslice every subject's images and append them to the given storages.

    :param image_files: Iterable of per-subject file collections.
    :param data_storage: Storage receiving the image channels.
    :param truth_storage: Storage receiving the label image (only when save_truth).
    :param image_shape: Target shape handed to reslice_image_set.
    :param n_channels: Number of leading images per subject treated as channels.
    :param affine_storage: Storage receiving the affine of each subject's first image.
    :param truth_dtype: dtype the label image is cast to before storing.
    :param crop: Forwarded to reslice_image_set.
    :param label_indices: Forwarded to reslice_image_set; when None, the index of
        the last file of each subject is used.
    :param save_truth: Whether the label image is written.
    :return: Tuple (data_storage, truth_storage).
    """
    for subject_files in image_files:
        # Default: the final file in the set is the label image.
        subject_label_indices = len(subject_files) - 1 if label_indices is None else label_indices
        resliced = reslice_image_set(subject_files, image_shape, label_indices=subject_label_indices, crop=crop)
        # NOTE(review): presumably nibabel images; get_data() is deprecated in
        # newer nibabel releases -- confirm against the pinned version.
        volumes = [img.get_data() for img in resliced]
        reference_affine = resliced[0].affine
        add_data_to_storage(data_storage, truth_storage, affine_storage, volumes, reference_affine, n_channels,
                            truth_dtype, save_truth=save_truth)
    return data_storage, truth_storage
def add_data_to_storage(data_storage, truth_storage, affine_storage, subject_data, affine, n_channels, truth_dtype,
                        save_truth=True):
    """
    Append one subject's arrays to the storages, adding the leading sample axis.

    :param data_storage: Receives the first n_channels arrays stacked, with one new leading axis.
    :param truth_storage: Receives subject_data[n_channels] cast to truth_dtype, with
        two new leading axes; skipped when save_truth is False.
    :param affine_storage: Receives the affine with one new leading axis.
    :param subject_data: Sequence of arrays for one subject.
    :param affine: Affine matrix for the subject.
    :param n_channels: Number of leading entries of subject_data treated as channels.
    :param truth_dtype: dtype used for the label array.
    :param save_truth: Whether to store the label array.
    """
    channel_stack = np.asarray(subject_data[:n_channels])
    data_storage.append(channel_stack[np.newaxis, ...])
    if save_truth:
        label_volume = np.asarray(subject_data[n_channels], dtype=truth_dtype)
        # Two new axes: sample axis plus a singleton channel axis.
        truth_storage.append(label_volume[np.newaxis, np.newaxis, ...])
    affine_matrix = np.asarray(affine)
    affine_storage.append(affine_matrix[np.newaxis, ...])
def write_data_to_file(training_data_files, out_file, image_shape, truth_dtype=np.uint8, subject_ids=None,
                       normalize=True, crop=True, save_truth=True):
    """
    Takes in a set of training images and writes those images to an hdf5 file.

    :param training_data_files: List of tuples containing the training data files. The modalities should be listed in
    the same order in each tuple. The last item in each tuple must be the labeled image. If the label image is not
    available, set save_truth to False.
    Example: [('sub1-T1.nii.gz', 'sub1-T2.nii.gz', 'sub1-truth.nii.gz'),
              ('sub2-T1.nii.gz', 'sub2-T2.nii.gz', 'sub2-truth.nii.gz')]
    :param out_file: Where the hdf5 file will be written to.
    :param image_shape: Shape of the images that will be saved to the hdf5 file.
    :param truth_dtype: Default is 8-bit unsigned integer.
    :param subject_ids: Optional sequence stored as the 'subject_ids' array of the file; skipped when empty or None.
    :param normalize: If True, normalize_data_storage is applied to the written image data.
    :param crop: Forwarded to the reslicing step.
    :param save_truth: If True, the last file of each tuple is treated as the label image and is not counted as a
    channel; if False, every file is a channel and no label image is written.
    :return: Location of the hdf5 file with the image data written to it.
    """
    n_samples = len(training_data_files)
    n_channels = len(training_data_files[0])
    if save_truth:
        # The trailing label image is not an input channel.
        n_channels = n_channels - 1

    try:
        hdf5_file, data_storage, truth_storage, affine_storage = create_data_file(out_file,
                                                                                  n_channels=n_channels,
                                                                                  n_samples=n_samples,
                                                                                  image_shape=image_shape)
    except Exception:
        # If something goes wrong, delete the incomplete data file. Removal is
        # best-effort: the file may not exist (or may be locked), and a failure
        # here must not hide the error that actually caused the problem.
        try:
            os.remove(out_file)
        except OSError:
            pass
        raise

    # With no label image there is no label index to pass to the reslicing step.
    label_indices = None if save_truth else []

    try:
        write_image_data_to_file(training_data_files, data_storage, truth_storage, image_shape,
                                 truth_dtype=truth_dtype, n_channels=n_channels, affine_storage=affine_storage,
                                 crop=crop, label_indices=label_indices, save_truth=save_truth)
        if subject_ids:
            hdf5_file.create_array(hdf5_file.root, 'subject_ids', obj=subject_ids)
        if normalize:
            normalize_data_storage(data_storage)
    finally:
        # Always release the HDF5 handle, including when writing fails midway.
        hdf5_file.close()
    return out_file
def open_data_file(filename, readwrite="r"):
    """
    Open an HDF5 file with PyTables.

    :param filename: Path of the file to open.
    :param readwrite: PyTables open mode; defaults to read-only ("r").
    :return: The opened PyTables file object; the caller must close it.
    """
    hdf5_handle = tables.open_file(filename, mode=readwrite)
    return hdf5_handle
def split_list(input_list, split=0.8, shuffle_list=True):
    """
    Split a sequence into a leading and a trailing part.

    :param input_list: Sequence of items to split. It is never modified.
    :param split: Fraction of the items placed in the first part; the cut index
        is int(len(input_list) * split).
    :param shuffle_list: If True, the items are randomly shuffled before splitting.
    :return: Tuple (training, testing).
    """
    # Shuffle a copy so the caller's list is not reordered as a hidden side effect.
    items = list(input_list) if shuffle_list else input_list
    if shuffle_list:
        shuffle(items)
    n_training = int(len(items) * split)
    return items[:n_training], items[n_training:]
def get_validation_split(data_file, training_file, validation_file, data_split=0.8, overwrite=False):
    """
    Return the training/validation sample indices, creating them if needed.

    A previously saved split is loaded when training_file exists and overwrite
    is False. Otherwise the indices 0..N-1 (N = number of rows in
    data_file.root.data) are split with split_list and saved with pickle_dump.

    :param data_file: Opened HDF5 file exposing root.data.
    :param training_file: Path where the training indices are stored.
    :param validation_file: Path where the validation indices are stored.
    :param data_split: Fraction of samples assigned to training.
    :param overwrite: If True, always create a new split.
    :return: Tuple (training_list, validation_list).
    """
    if not overwrite and os.path.exists(training_file):
        print("Loading previous validation split...")
        return pickle_load(training_file), pickle_load(validation_file)

    print("Creating validation split...")
    n_total = data_file.root.data.shape[0]
    all_indices = list(range(n_total))
    training_list, validation_list = split_list(all_indices, split=data_split)
    pickle_dump(training_list, training_file)
    pickle_dump(validation_list, validation_file)
    return training_list, validation_list
def fetch_training_data_files(data_dir, return_subject_ids=True):
    """
    Collect the expected image paths for every subject under data_dir.

    Every path matching data_dir/*/* is treated as a subject directory. For each
    one, a tuple of "<name>.nii.gz" paths is built for the names in
    config["all_modalities"] followed by "truth". Paths are constructed only;
    their existence is not checked.

    :param data_dir: Root directory to search.
    :param return_subject_ids: If True, also return the subject directory base names.
    :return: training_data_files, or (training_data_files, subject_ids).
    """
    training_data_files = []
    subject_ids = []
    for subject_dir in glob.glob(os.path.join(data_dir, "*", "*")):
        subject_ids.append(os.path.basename(subject_dir))
        training_data_files.append(
            tuple(os.path.join(subject_dir, name + ".nii.gz")
                  for name in config["all_modalities"] + ["truth"])
        )
    if not return_subject_ids:
        return training_data_files
    return training_data_files, subject_ids
if __name__ == '__main__':
    # Build the HDF5 data file from the "data" directory next to this script,
    # then create (or load) the training/validation split.
    data_dir = os.path.join(os.path.dirname(__file__), "data")
    training_files, subject_ids = fetch_training_data_files(data_dir, return_subject_ids=True)
    write_data_to_file(training_files, config["data_file"], image_shape=config["image_shape"],
                       subject_ids=subject_ids)
    data_file_opened = open_data_file(config["data_file"])
    try:
        get_validation_split(data_file_opened, config["training_file"], config["validation_file"])
    finally:
        # Release the HDF5 handle even if creating the split fails.
        data_file_opened.close()