### Data class and associated helper methods
import numpy as np
import h5py
import os
import time
from threading import Thread
import itertools
class FilePreloader(Thread):
    """Background thread that keeps up to n_ahead of the listed files open ahead of use."""

    def __init__(self, files_list, file_open, n_ahead=2):
        Thread.__init__(self)
        self.daemon = True
        self.n_concurrent = n_ahead
        self.files_list = files_list
        self.file_open = file_open
        self.loaded = {}  ## a dict of the loaded file objects, keyed by file name
        self.should_stop = False

    def getFile(self, name):
        ## return the open handle, loading the file synchronously
        ## if the preloading thread has not opened it yet
        return self.loaded.setdefault(name, self.file_open(name))

    def closeFile(self, name):
        ## close the file and drop it from the cache so that
        ## the preloading loop can move on to the next one
        if name in self.loaded:
            self.loaded.pop(name).close()

    def run(self):
        ## wait until a file list has been provided
        while not self.files_list:
            time.sleep(1)
        for name in itertools.cycle(self.files_list):
            if self.should_stop:
                break
            n_there = len(self.loaded)
            if n_there < self.n_concurrent:
                print("preloading", name, "(%d file(s) already open)" % n_there)
                self.getFile(name)
            else:
                time.sleep(5)

    def stop(self):
        print("Stopping FilePreloader")
        self.should_stop = True
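
## A minimal usage sketch (not part of the original API surface): H5Data below
## normally owns the FilePreloader instance, but it can also be driven directly.
## The file names here are hypothetical placeholders.
def _example_preloader_usage():
    fpl = FilePreloader([], file_open=lambda n: h5py.File(n, 'r'), n_ahead=2)
    fpl.start()
    fpl.files_list = ['train_0.h5', 'train_1.h5']  # hypothetical files; run() starts cycling over them
    h5 = fpl.getFile('train_0.h5')   # returns the open handle, loading it on demand if needed
    print(list(h5.keys()))
    fpl.closeFile('train_0.h5')      # release the handle so the thread can preload the next file
    fpl.stop()
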
def data_class_getter(name):
    """Returns the specified Data class"""
    data_dict = {
        "H5Data": H5Data,
    }
    try:
        return data_dict[name]
    except KeyError:
        print("{0:s} is not a known Data class. Returning None...".format(name))
        return None
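
## Usage sketch for the lookup above; the batch size is an arbitrary example value.
def _example_data_class_getter():
    cls = data_class_getter("H5Data")            # -> the H5Data class
    assert data_class_getter("Unknown") is None  # prints a warning and returns None
    return cls(batch_size=128) if cls else None
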
class Data(object):
    """Class providing an interface to the input training data.
    Derived classes should implement the load_data function.

    Attributes:
      file_names: list of data files to use for training
      batch_size: size of training batches
    """

    def __init__(self, batch_size, cache=None, spectators=False):
        """Stores the batch size and the names of the data files to be read.
        Params:
          batch_size: batch size for training
          cache: directory used to cache local copies of the input files;
            defaults to the GANINMEM environment variable
          spectators: whether load_data also returns spectator variables
        """
        self.batch_size = batch_size
        self.caching_directory = cache if cache else os.environ.get('GANINMEM', '')
        self.spectators = spectators
        self.fpl = None

    def set_caching_directory(self, cache):
        self.caching_directory = cache
    def set_file_names(self, file_names):
        ## optionally relocate the input files to a per-process cache
        ## directory (e.g. /dev/shm) before using them
        relocated = []
        if self.caching_directory:
            goes_to = self.caching_directory
            goes_to += str(os.getpid())
            os.system('mkdir %s' % goes_to)
            os.system('rm %s/* -f' % goes_to)  ## clean up first, in case anything is left over
            for fn in file_names:
                relocate = goes_to + '/' + fn.split('/')[-1]
                if not os.path.isfile(relocate):
                    print("copying %s to %s" % (fn, relocate))
                    if os.system('cp %s %s' % (fn, relocate)) == 0:
                        relocated.append(relocate)
                    else:
                        print("was unable to copy the file", fn, "to", relocate)
                        relocated.append(fn)  ## fall back to the original file
                else:
                    relocated.append(relocate)
            self.file_names = relocated
        else:
            self.file_names = file_names
        if self.fpl:
            self.fpl.files_list = self.file_names
    def inf_generate_data(self):
        """Loops over generate_data forever, restarting when the files are exhausted."""
        while True:
            try:
                for B in self.generate_data():
                    yield B
            except StopIteration:
                print("start over generator loop")

    def inf_generate_data_keras(self):
        while True:
            try:
                for B, C, D in self.generate_data():
                    ## swap the last two axes of the selected feature arrays
                    ## to match the Keras model's input layout
                    yield [B[2].swapaxes(1, 2), B[3].swapaxes(1, 2)], C
            except StopIteration:
                print("start over generator loop")

    def inf_generate_data_keras_db(self):
        while True:
            try:
                for B, C, D in self.generate_data():
                    yield [B[0], B[2].swapaxes(1, 2)[:, :, 22:], B[3].swapaxes(1, 2)[:, :, 11:13]], C
            except StopIteration:
                print("start over generator loop")

    def inf_generate_data_keras_cpf(self):
        while True:
            try:
                for B, C, D in self.generate_data():
                    yield [B[2].swapaxes(1, 2)], C
            except StopIteration:
                print("start over generator loop")
    def generate_data(self):
        """Yields batches of training data until none are left."""
        leftovers = None
        for cur_file_name in self.file_names:
            if self.spectators:
                cur_file_features, cur_file_labels, cur_file_spectators = self.load_data(cur_file_name)
            else:
                cur_file_features, cur_file_labels = self.load_data(cur_file_name)
            # concatenate any leftover data from the previous file
            if leftovers is not None:
                cur_file_features = self.concat_data(leftovers[0], cur_file_features)
                cur_file_labels = self.concat_data(leftovers[1], cur_file_labels)
                if self.spectators:
                    cur_file_spectators = self.concat_data(leftovers[2], cur_file_spectators)
                leftovers = None
            num_in_file = self.get_num_samples(cur_file_features)
            for cur_pos in range(0, num_in_file, self.batch_size):
                next_pos = cur_pos + self.batch_size
                if next_pos <= num_in_file:
                    if self.spectators:
                        yield (self.get_batch(cur_file_features, cur_pos, next_pos),
                               self.get_batch(cur_file_labels, cur_pos, next_pos),
                               self.get_batch(cur_file_spectators, cur_pos, next_pos))
                    else:
                        yield (self.get_batch(cur_file_features, cur_pos, next_pos),
                               self.get_batch(cur_file_labels, cur_pos, next_pos))
                else:
                    # save the remainder to be merged with the next file's data
                    if self.spectators:
                        leftovers = (self.get_batch(cur_file_features, cur_pos, num_in_file),
                                     self.get_batch(cur_file_labels, cur_pos, num_in_file),
                                     self.get_batch(cur_file_spectators, cur_pos, num_in_file))
                    else:
                        leftovers = (self.get_batch(cur_file_features, cur_pos, num_in_file),
                                     self.get_batch(cur_file_labels, cur_pos, num_in_file))
    def count_data(self):
        """Counts the number of data points across all files"""
        num_data = 0
        for cur_file_name in self.file_names:
            if self.spectators:
                cur_file_features, cur_file_labels, cur_file_spectators = self.load_data(cur_file_name)
            else:
                cur_file_features, cur_file_labels = self.load_data(cur_file_name)
            num_data += self.get_num_samples(cur_file_features)
        return num_data
    def is_numpy_array(self, data):
        return isinstance(data, np.ndarray)

    def get_batch(self, data, start_pos, end_pos):
        """Input: a numpy array or list of numpy arrays.
        Gets elements between start_pos and end_pos in each array"""
        if self.is_numpy_array(data):
            return data[start_pos:end_pos]
        else:
            return [arr[start_pos:end_pos] for arr in data]

    def concat_data(self, data1, data2):
        """Input: data1 as a numpy array or list of numpy arrays; data2 in the same format.
        Returns: numpy array or list of arrays, in which each array in data1 has been
        concatenated with the corresponding array in data2"""
        if self.is_numpy_array(data1):
            return np.concatenate((data1, data2))
        else:
            return [self.concat_data(d1, d2) for d1, d2 in zip(data1, data2)]

    def get_num_samples(self, data):
        """Input: dataset consisting of a numpy array or list of numpy arrays.
        Output: number of samples in the dataset"""
        if self.is_numpy_array(data):
            return len(data)
        else:
            return len(data[0])

    def load_data(self, in_file):
        """Input: name of the file from which the data should be loaded.
        Returns: tuple (X, Y) where X and Y are numpy arrays containing features
        and labels, respectively, for all data in the file.
        Not implemented in the base class; derived classes should implement this function."""
        raise NotImplementedError
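
## Sketch of the intended batch-iteration pattern, using the concrete H5Data
## subclass defined below; the file name and batch size are hypothetical.
def _example_batch_iteration():
    data = H5Data(batch_size=64)          # H5Data is defined below, resolved at call time
    data.set_file_names(['train_0.h5'])   # hypothetical file
    for X, Y in data.generate_data():     # one finite pass over all files
        pass                              # consume one (features, labels) batch here
    gen = data.inf_generate_data()        # endless variant, e.g. for Keras-style fit loops
    X, Y = next(gen)
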
class H5Data(Data):
    """Loads data stored in HDF5 files.
    Attributes:
      features_name, labels_name, spectators_name: names of the datasets
        containing the features, labels, and spectators respectively
    """

    def __init__(self, batch_size,
                 cache=None,
                 preloading=0,
                 features_name='features',
                 labels_name='labels',
                 spectators_name=None):
        """Initializes and stores names of feature and label datasets"""
        super(H5Data, self).__init__(batch_size, cache, (spectators_name is not None))
        self.features_name = features_name
        self.labels_name = labels_name
        self.spectators_name = spectators_name
        ## initialize the file preloader, if requested
        self.fpl = None
        if preloading:
            self.fpl = FilePreloader([], file_open=lambda n: h5py.File(n, 'r'), n_ahead=preloading)
            self.fpl.start()
    def load_data(self, in_file_name):
        """Loads numpy arrays from an H5 file.
        If the features/labels groups contain more than one dataset,
        we load them all, ordered alphabetically by key."""
        if self.fpl:
            h5_file = self.fpl.getFile(in_file_name)
        else:
            h5_file = h5py.File(in_file_name, 'r')
        X = self.load_hdf5_data(h5_file[self.features_name])
        Y = self.load_hdf5_data(h5_file[self.labels_name])
        if self.spectators_name is not None:
            Z = self.load_hdf5_data(h5_file[self.spectators_name])
        if self.fpl:
            self.fpl.closeFile(in_file_name)
        else:
            h5_file.close()
        if self.spectators_name is not None:
            return X, Y, Z
        else:
            return X, Y
    def load_hdf5_data(self, data):
        """Returns a numpy array or (possibly nested) list of numpy arrays
        corresponding to the group structure of the input HDF5 data.
        If a group has more than one key, its datasets are returned in
        alphabetical order by key."""
        if hasattr(data, 'keys'):
            out = [self.load_hdf5_data(data[key]) for key in sorted(data.keys())]
        else:
            out = data[:]
        return out
    def count_data(self):
        """This is faster than using the parent count_data
        because the datasets do not have to be loaded
        as numpy arrays"""
        num_data = 0
        for in_file_name in self.file_names:
            h5_file = h5py.File(in_file_name, 'r')
            X = h5_file[self.features_name]
            if hasattr(X, 'keys'):
                ## for a group, count the samples in its first dataset
                num_data += len(X[list(X.keys())[0]])
            else:
                num_data += len(X)
            h5_file.close()
        return num_data
    def finalize(self):
        ## stop the preloading thread, if one was started
        if self.fpl:
            self.fpl.stop()
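
## End-to-end sketch, runnable standalone: build a tiny HDF5 file matching the
## default layout H5Data expects (one 'features' and one 'labels' dataset), then
## stream batches from it. The shapes and the file name are illustrative only.
if __name__ == '__main__':
    demo_file = 'demo.h5'
    with h5py.File(demo_file, 'w') as f:
        f.create_dataset('features', data=np.random.rand(100, 4))
        f.create_dataset('labels', data=np.random.randint(0, 2, size=(100,)))
    data = H5Data(batch_size=32)
    data.set_file_names([demo_file])
    print("total samples:", data.count_data())    # fast count straight from the HDF5 metadata
    for X, Y in data.generate_data():
        print("batch shapes:", X.shape, Y.shape)  # three full batches of 32; the remainder is dropped
    data.finalize()
    os.remove(demo_file)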