-
Notifications
You must be signed in to change notification settings - Fork 50
/
DataSetSamplingPascal.lua
561 lines (506 loc) · 20.9 KB
/
DataSetSamplingPascal.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
--------------------------------------------------------------------------------
-- DataSetSamplingPascal: A class to handle datasets from LabelMe (and other segmentation
-- based datasets).
--
-- Provides lots of options to cache (on disk) datasets, precompute
-- segmentation masks, shuffle samples, extract subpatches, ...
--
-- Authors: Clement Farabet, Benoit Corda
--------------------------------------------------------------------------------
-- Class table, registered in torch under the 'nn' namespace.
local DataSetSamplingPascal = torch.class('nn.DataSetSamplingPascal')
-- Names of the sub-directories expected under the dataset root (self.path).
local path_images, path_annotations, path_masks = 'Images', 'Annotations', 'Masks'
--------------------------------------------------------------------------------
-- Constructor. Parses the named arguments below, scans <path>/Images for raw
-- samples (flat, or one level of sub-folders), records size statistics,
-- parses every mask into per-class tag lists, builds the lookup table used by
-- the 'random' sampling mode and optionally preloads all samples in memory.
-- Raises an error when no image is found or samplingMode is unknown.
--------------------------------------------------------------------------------
function DataSetSamplingPascal:__init(...)
   -- check args
   xlua.unpack_class(
      self,
      {...},
      'DataSetSamplingPascal',
      'Creates a DataSet from standard LabelMe directories (Images+Annotations)',
      {arg='path', type='string', help='path to LabelMe directory', req=true},
      {arg='nbClasses', type='number', help='number of classes in dataset', default=1},
      {arg='classNames', type='table', help='list of class names', default={'no name'}},
      {arg='nbRawSamples', type='number', help='number of images'},
      {arg='nbSegments', type='number', help='number of segment per image in dataset', default=100},
      {arg='rawSampleMaxSize', type='number', help='resize all images to fit in a MxM window'},
      {arg='rawSampleSize', type='table', help='resize all images precisely: {w=,h=}}'},
      {arg='rawMaskRescale',type='boolean',help='does are the N classes spread between 0->255 in the PNG and need to be rescaled',default=true},
      {arg='nbPatchPerSample', type='number', help='number of patches to extract from each image', default=100},
      {arg='patchSize', type='number', help='size of patches to extract from images', default=64},
      {arg='samplingMode', type='string', help='patch sampling method: random | equal', default='random'},
      {arg='samplingFilter', type='table', help='a filter to sample patches: {ratio=,size=,step}'},
      {arg='labelType', type='string', help='type of label returned: center | pixelwise', default='center'},
      {arg='labelGenerator', type='function', help='a function to generate sample+target (bypasses labelType)'},
      {arg='infiniteSet', type='boolean', help='if true, the set can be indexed to infinity, looping around samples', default=false},
      {arg='classToSkip', type='number', help='index of class to skip during sampling', default=0},
      {arg='preloadSamples', type='boolean', help='if true, all samples are preloaded in memory', default=false},
      {arg='cacheFile', type='string', help='path to cache file (once cached, loading is much faster)'},
      {arg='verbose', type='boolean', help='dumps information', default=false},
      {arg='nbSegmentsToExtract', type='number', help='number of segments per image to be extracted', default=1}
   )

   -- fixed parameters
   self.colorMap = image.colormap(self.nbClasses)
   self.rawdata = {}
   self.currentIndex = -1
   -- location of the last extracted patch, relative to the full image size
   self.currentX = 0
   self.currentY = 0
   self.realIndex = -1
   self.currentSegment = 0

   -- parse dir structure: files may sit directly in Images/ or in one level
   -- of sub-folders
   print('<DataSetSamplingPascal> loading LabelMe dataset from '..self.path)
   for folder in paths.files(paths.concat(self.path,path_images)) do
      if folder ~= '.' and folder ~= '..' then
         if sys.filep(paths.concat(self.path,path_images,folder)) then
            self:getsizes('./',folder)
         else
            -- loop though nested folders
            for file in paths.files(paths.concat(self.path,path_images,folder)) do
               if file ~= '.' and file ~= '..' then
                  self:getsizes(folder,file)
               end
            end
         end
      end
   end
   if #self.rawdata == 0 then
      error('ERROR <DataSetSamplingPascal> no images found in '
            .. paths.concat(self.path,path_images))
   end

   -- nb samples: user defined (capped to the nb of files found) or max
   self.nbRawSamples = math.min(self.nbRawSamples or #self.rawdata, #self.rawdata)

   -- extract some info (max sizes); size = {channels, height, width}
   self.maxY = self.rawdata[1].size[2]
   self.maxX = self.rawdata[1].size[3]
   for i = 2,self.nbRawSamples do
      self.maxX = math.max(self.maxX, self.rawdata[i].size[3])
      self.maxY = math.max(self.maxY, self.rawdata[i].size[2])
   end

   -- and nb of samples obtainable (this is overcomplete ;-)
   self.nbSamples = self.nbPatchPerSample * self.nbRawSamples

   -- max size: explicit, derived from rawSampleSize, or the largest image;
   -- never larger than the largest image
   local maxXY = math.max(self.maxX, self.maxY)
   if not self.rawSampleMaxSize then
      if self.rawSampleSize then
         self.rawSampleMaxSize =
            math.max(self.rawSampleSize.w,self.rawSampleSize.h)
      else
         self.rawSampleMaxSize = maxXY
      end
   end
   if maxXY < self.rawSampleMaxSize then
      self.rawSampleMaxSize = maxXY
   end

   -- some info
   if self.verbose then
      print(self)
   end

   -- sampling mode
   if self.samplingMode == 'equal' or self.samplingMode == 'random' then
      self:parseAllMasks()
      if self.samplingMode == 'random' then
         -- get the number of usable tag slots (classToSkip excluded)
         self.nbRandomPatches = 0
         for i,v in ipairs(self.tags) do
            if i ~= self.classToSkip then
               self.nbRandomPatches = self.nbRandomPatches + v.size
            end
         end
         -- lookup table: entry k holds a class id, each class occupying as
         -- many entries as it has tag slots, so a uniform draw over entries
         -- picks classes proportionally to their nb of patches.
         -- Byte storage keeps it small but only holds ids up to 255.
         local lookupType = (self.nbClasses <= 255) and torch.ByteTensor or torch.LongTensor
         self.randomLookup = lookupType(self.nbRandomPatches)
         local idx = 1
         for i,v in ipairs(self.tags) do
            if i ~= self.classToSkip and v.size > 0 then
               self.randomLookup:narrow(1,idx,v.size):fill(i)
               idx = idx + v.size
            end
         end
      end
   else
      error('ERROR <DataSetSamplingPascal> unknown sampling mode')
   end

   -- preload ?
   if self.preloadSamples then
      self:preload()
   end
end
--------------------------------------------------------------------------------
-- Registers one raw sample in self.rawdata: derives its mask and annotation
-- paths from the image file name and records its size as
-- {channels, height, width} (as returned by the image/mattorch readers).
-- @param folder sub-folder of <path>/Images holding the file ('./' if flat)
-- @param file   image file name: .jpg, .png or .mat
--------------------------------------------------------------------------------
function DataSetSamplingPascal:getsizes(folder,file)
   -- masks of JPG images are PNGs with the same base name (PNG and MAT
   -- images keep their own name); annotations always use the .xml extension
   local filepng = (file:gsub('%.jpg$','.png'))
   local filexml = (file:gsub('%.[^.]+$','.xml'))
   local imgf = paths.concat(self.path,path_images,folder,file)
   local maskf = paths.concat(self.path,path_masks,folder,filepng)
   local annotf = paths.concat(self.path,path_annotations,folder,filexml)

   local size_c, size_y, size_x
   if file:find('%.jpg$') then
      size_c, size_y, size_x = image.getJPGsize(imgf)
   elseif file:find('%.png$') then
      size_c, size_y, size_x = image.getPNGsize(imgf)
   elseif file:find('%.mat$') then
      if not xrequire 'mattorch' then
         xerror('<DataSetSamplingPascal> mattorch package required to handle MAT files')
      end
      -- a MAT file may hold several variables: use the first one found
      local _, matrix = next(mattorch.load(imgf))
      size_c = matrix:size(1)
      size_y = matrix:size(2)
      size_x = matrix:size(3)
      -- only the size is needed: release the data right away
      matrix = nil
      collectgarbage()
   else
      xerror('images must either be JPG, PNG or MAT files', 'DataSetSamplingPascal')
   end

   table.insert(self.rawdata, {imgfile=imgf,
                               maskfile=maskf,
                               annotfile=annotf,
                               size={size_c, size_y, size_x}})
end
--------------------------------------------------------------------------------
-- Advertised number of patches in the set (nbPatchPerSample * nbRawSamples,
-- as computed by the constructor).
--------------------------------------------------------------------------------
function DataSetSamplingPascal.size(self)
   local count = self.nbSamples
   return count
end
--------------------------------------------------------------------------------
-- Human readable summary of the dataset configuration (used by print(self)).
-- @return multi-line string
--------------------------------------------------------------------------------
function DataSetSamplingPascal:__tostring__()
   -- collect pieces and join once, instead of repeated concatenation
   local buf = {'DataSetSamplingPascal:\n'}
   local function add(piece)
      buf[#buf+1] = piece
   end

   add(' + path : '..self.path..'\n')
   if self.cacheFile then
      add(' + cache files : [path]/'..self.cacheFile..'-[tags|samples]\n')
   end
   add(' + nb samples : '..self.nbRawSamples..'\n')
   add(' + nb generated patches : '..self.nbSamples..'\n')
   if self.infiniteSet then
      add(' + infinite set (actual nb of samples >> set:size())\n')
   end
   if self.rawSampleMaxSize then
      add(' + samples are resized to fit in a '
          .. self.rawSampleMaxSize .. 'x' .. self.rawSampleMaxSize .. ' tensor'
          .. ' [max raw size = ' .. self.maxX .. 'x' .. self.maxY .. ']\n')
      if self.rawSampleSize then
         add(' + imposed ratio of ' .. self.rawSampleSize.w .. 'x' .. self.rawSampleSize.h .. '\n')
      end
   end
   add(' + patches size : ' .. self.patchSize .. 'x' .. self.patchSize .. '\n')
   if self.classToSkip ~= 0 then
      -- tostring: classToSkip may not have a matching entry in classNames
      add(' + unused class : ' .. tostring(self.classNames[self.classToSkip]) .. '\n')
   end
   add(' + sampling mode : ' .. self.samplingMode .. '\n')
   if not self.labelGenerator then
      add(' + label type : ' .. self.labelType .. '\n')
   else
      add(' + label type : generated by user function \n')
   end
   add(' + '..self.nbClasses..' categories : ')
   -- 'a | b | c'; an empty classNames list yields an empty string
   add(table.concat(self.classNames, ' | '))
   return table.concat(buf)
end
--------------------------------------------------------------------------------
-- Indexing hook (torch class __index__). For a numeric key, extracts one
-- patch and returns (value, true); any other key returns rawget(self, key)
-- so that normal field/method lookup applies.
--
-- Patch centres come from self.tags[class].data, where each tagged patch
-- occupies 4 consecutive slots: x, y, raw sample index, segment id.
--   'random' : class drawn through self.randomLookup (proportional to its
--              nb of tag slots), then a random patch of that class.
--   'equal'  : classes visited round-robin by key; empty/skipped classes are
--              replaced by a class drawn uniformly among usable ones.
-- Side effects: loads the raw sample (loadSample) and updates currentX,
-- currentY (box corner relative to image size) and currentSegment.
-- Returned value: the labelGenerator result, or {patch, target} with target
-- a +1/-1 class vector ('center') or the mask patch ('pixelwise'); false for
-- an unknown labelType.
--------------------------------------------------------------------------------
function DataSetSamplingPascal:__index__(key)
   if type(key) ~= 'number' then
      return rawget(self,key)
   end

   -- select sample, according to samplingMode
   local box_size = self.patchSize
   local ctr_target, tag_idx
   if self.samplingMode == 'random' then
      -- get indexes from random table
      ctr_target = self.randomLookup[math.random(1,self.nbRandomPatches)]
      tag_idx = math.floor(math.random(0,self.tags[ctr_target].size-1)/4)*4+1
   elseif self.samplingMode == 'equal' then
      -- equally sample each category:
      ctr_target = ((key-1) % (self.nbClasses)) + 1
      if self.tags[ctr_target].size == 0 or ctr_target == self.classToSkip then
         -- no sample in that class: replace it with a class drawn uniformly
         -- among the classes that can actually be sampled (every class,
         -- including the last one, is eligible; fails loudly instead of
         -- looping forever when none is)
         local usable = {}
         for c = 1,self.nbClasses do
            if self.tags[c].size > 0 and c ~= self.classToSkip then
               usable[#usable+1] = c
            end
         end
         if #usable == 0 then
            error('ERROR <DataSetSamplingPascal> no class has any patch to sample')
         end
         ctr_target = usable[math.random(1,#usable)]
      end
      if self.infiniteSet then
         tag_idx = math.random(1,self.tags[ctr_target].size/4)
      else
         tag_idx = math.floor((key-1)/self.nbClasses) + 1
      end
      -- wrap around the class' patch list, then convert to a slot offset
      tag_idx = ((tag_idx-1) % (self.tags[ctr_target].size/4))*4 + 1
   end

   -- generate patch
   local slots = self.tags[ctr_target].data
   self:loadSample(slots[tag_idx+2])
   local full_sample = self.currentSample
   local full_mask = self.currentMask
   local ctr_x = slots[tag_idx]
   local ctr_y = slots[tag_idx+1]
   local box_x = math.floor(ctr_x - box_size/2) + 1
   self.currentX = box_x/full_sample:size(3)
   local box_y = math.floor(ctr_y - box_size/2) + 1
   self.currentY = box_y/full_sample:size(2)
   self.currentSegment = slots[tag_idx+3]

   -- extract sample + mask:
   local sample = full_sample:narrow(2,box_y,box_size):narrow(3,box_x,box_size)
   local mask = full_mask:narrow(1,box_y,box_size):narrow(2,box_x,box_size)

   -- finally, generate the target, either using an arbitrary user function,
   -- or a built-in label type
   if self.labelGenerator then
      -- call user function to generate sample+label
      local ret = self:labelGenerator(full_sample, full_mask, sample, mask,
                                      ctr_target, ctr_x, ctr_y, box_x, box_y, box_size,
                                      self.currentSegment,self.realIndex)
      return ret, true
   elseif self.labelType == 'center' then
      -- generate label vector for patch
      local vector = torch.Tensor(self.nbClasses):fill(-1)
      vector[ctr_target] = 1
      return {sample, vector}, true
   elseif self.labelType == 'pixelwise' then
      -- generate pixelwise annotation
      return {sample, mask}, true
   else
      return false
   end
end
--------------------------------------------------------------------------------
-- Makes raw sample number `index` the current one: sets currentSample,
-- currentMask, currentIndex and realIndex. Does nothing if that sample is
-- already current.
--
-- After preload() the sample is taken from self.preloaded; otherwise the
-- image and mask are read from disk (image.load, or mattorch for .mat
-- files), optionally rescaled (rawSampleSize / rawSampleMaxSize), and the
-- mask values are converted to 1-based class ids.
-- @param index position of the sample in self.rawdata
--------------------------------------------------------------------------------
function DataSetSamplingPascal:loadSample(index)
   if index == self.currentIndex then
      return
   end

   -- image path without its .jpg extension (handed to label generators and
   -- used by parseMask to locate segment files); refreshed on every switch,
   -- for preloaded samples too, so it always matches the current sample
   self.realIndex = (self.rawdata[index].imgfile:gsub('%.jpg$',''))

   if self.preloadedDone then
      self.currentSample = self.preloaded.samples[index]
      self.currentMask = self.preloaded.masks[index]
      self.currentIndex = index
      return
   end

   -- clean up
   self.currentSample = nil
   self.currentMask = nil
   collectgarbage()

   -- matlab or regular images ?
   local matlab = false
   if self.rawdata[index].imgfile:find('%.mat$') then
      if not xrequire 'mattorch' then
         xerror('<DataSetSamplingPascal> mattorch package required to handle MAT files')
      end
      matlab = true
   end

   -- load image
   local img_loaded, mask_loaded
   if matlab then
      img_loaded = mattorch.load(self.rawdata[index].imgfile)
      mask_loaded = mattorch.load(self.rawdata[index].maskfile)
      -- MAT files may hold several variables: keep the first one of each
      for _,matrix in pairs(img_loaded) do
         img_loaded = matrix
         break
      end
      for _,matrix in pairs(mask_loaded) do
         mask_loaded = matrix
         break
      end
      -- swap the two spatial axes (presumably MATLAB's column-major
      -- layout — TODO confirm)
      img_loaded = img_loaded:transpose(2,3)
      mask_loaded = mask_loaded:transpose(1,2)
   else
      img_loaded = image.load(self.rawdata[index].imgfile)
      mask_loaded = image.load(self.rawdata[index].maskfile)[1]
   end

   -- resize ?
   if self.rawSampleSize then
      -- resize precisely
      local w = self.rawSampleSize.w
      local h = self.rawSampleSize.h
      self.currentSample = torch.Tensor(img_loaded:size(1),h,w)
      image.scale(img_loaded, self.currentSample, 'bilinear')
      self.currentMask = torch.Tensor(h,w)
      -- 'simple' (no interpolation) keeps mask values unchanged
      image.scale(mask_loaded, self.currentMask, 'simple')
   elseif self.rawSampleMaxSize and (self.rawSampleMaxSize < img_loaded:size(3)
                                     or self.rawSampleMaxSize < img_loaded:size(2)) then
      -- resize to fit in bounding box, keeping the aspect ratio
      local w,h
      if img_loaded:size(3) >= img_loaded:size(2) then
         w = self.rawSampleMaxSize
         h = math.floor((w*img_loaded:size(2))/img_loaded:size(3))
      else
         h = self.rawSampleMaxSize
         w = math.floor((h*img_loaded:size(3))/img_loaded:size(2))
      end
      self.currentSample = torch.Tensor(img_loaded:size(1),h,w)
      image.scale(img_loaded, self.currentSample, 'bilinear')
      self.currentMask = torch.Tensor(h,w)
      image.scale(mask_loaded, self.currentMask, 'simple')
   else
      self.currentSample = img_loaded
      self.currentMask = mask_loaded
   end

   -- process mask: convert to 1-based class ids
   if matlab then
      if self.currentMask:min() == 0 then
         self.currentMask:add(1)
      end
   elseif self.rawMaskRescale then
      -- stanford dataset style (png contains 0 and 255): classes are spread
      -- over the value range (assumes image.load returns values in [0,1])
      self.currentMask:mul(self.nbClasses-1):add(0.5):floor():add(1)
   else
      -- PNG already stores values at the correct classes
      -- only holds values from 0 to nclasses
      self.currentMask:mul(255):add(1):add(0.5):floor()
   end
   self.currentIndex = index
end
--------------------------------------------------------------------------------
-- Loads every raw sample (and its mask) into self.preloaded.{samples,masks}
-- and sets self.preloadedDone, after which loadSample serves from memory.
-- When <path>/<cacheFile>-samples exists it is deserialized instead of
-- reading the images; otherwise, if a cache file name is known (cacheFile or
-- the saveFile argument, which overrides it), the result is written to it.
-- @param saveFile optional cache file name to store the samples under
--------------------------------------------------------------------------------
function DataSetSamplingPascal:preload(saveFile)
   -- evaluated lazily: self.cacheFile may be (re)assigned further down
   local function cachePath()
      return paths.concat(self.path,self.cacheFile..'-samples')
   end

   -- fast path: reuse an existing cache file
   if self.cacheFile and paths.filep(cachePath()) then
      print('<DataSetSamplingPascal> retrieving saved samples from :'
            .. cachePath()
            .. ' [delete file to force new scan]')
      local reader = torch.DiskFile(cachePath(), 'r')
      reader:binary()
      self.preloaded = reader:readObject()
      reader:close()
      self.preloadedDone = true
      return
   end

   print('<DataSetSamplingPascal> preloading all images')
   local samples, masks = {}, {}
   self.preloaded = {samples=samples, masks=masks}
   for idx = 1,self.nbRawSamples do
      xlua.progress(idx,self.nbRawSamples)
      self:loadSample(idx)
      -- keep copies in the default tensor type (torch.Tensor)
      samples[idx] = torch.Tensor(self.currentSample:size()):copy(self.currentSample)
      masks[idx] = torch.Tensor(self.currentMask:size()):copy(self.currentMask)
   end
   self.preloadedDone = true

   if saveFile then
      self.cacheFile = saveFile
   end
   -- serialize the preloaded samples when a cache file name is known
   if self.cacheFile then
      print('<DataSetSamplingPascal> saving samples to cache file: '
            .. cachePath())
      local writer = torch.DiskFile(cachePath(), 'w')
      writer:binary()
      writer:writeObject(self.preloaded)
      writer:close()
   end
end
--------------------------------------------------------------------------------
-- Appends the patch candidates of the current sample to per-class tag lists.
-- For each of the first nbSegmentsToExtract segments stored in the sample's
-- Segments/*.mat file (variable top_masks), the ground-truth mask is
-- restricted to that segment and handed to the C extractor
-- DataSetSegmentSampling_extract, which fills tags[class].data/size.
-- self.currentMask is left untouched.
-- @param existing_tags tag lists to extend (a new set is allocated if nil)
-- @return the tag lists: tags[class] = {data=ShortStorage, size=nb of slots}
--------------------------------------------------------------------------------
function DataSetSamplingPascal:parseMask(existing_tags)
   -- storage added per class and per segment: 4 slots for every pixel of a
   -- rawSampleMaxSize x rawSampleMaxSize image
   local chunk = self.rawSampleMaxSize*self.rawSampleMaxSize*4
   local tags = existing_tags
   if not tags then
      tags = {}
      for i = 1,self.nbClasses do
         tags[i] = {data=torch.ShortStorage(chunk), size=0}
      end
   end

   -- use filter
   local filter = self.samplingFilter or {ratio=0, size=self.patchSize, step=4}

   -- shift labels to 0-based ids on a copy: self.currentMask is cached by
   -- loadSample and must keep its 1-based labels for later patch extraction
   local mask = self.currentMask:clone():add(-1)

   -- keep patch centres far enough from the borders
   local half = math.ceil(self.patchSize/2)
   local x_start = half
   local x_end = mask:size(2) - half
   local y_start = half
   local y_end = mask:size(1) - half

   -- segment proposals live in a .mat file mirroring the image path
   if not xrequire 'mattorch' then
      xerror('<DataSetSamplingPascal> mattorch package required to handle MAT files')
   end
   local file = sys.concat((self.realIndex:gsub('Images','Segments')),'.mat')
   local mat_path = (file:gsub('/%.mat$','.mat'))
   local segments = mattorch.load(mat_path).top_masks:float()

   -- never read past the segments available in the file
   -- (note: the nbSegments option is not used here)
   local nb_segs = math.min(self.nbSegmentsToExtract, segments:size(1))
   -- stride parameter passed to the C extractor
   local step = math.ceil(math.sqrt(self.nbSegmentsToExtract))

   for ids = 1, nb_segs do
      -- make sure each tag list is large enough to hold the incoming data
      for i = 1,self.nbClasses do
         if tags[i].size + chunk > tags[i].data:size() then
            tags[i].data:resize(tags[i].size + chunk)
         end
      end
      -- mask the ground truth with this segment, back to 1-based ids
      -- (assumes segment masks are binary 0/1 — TODO confirm)
      local segment1 = segments[ids]:t()
      segment1:cmul(mask):add(1)
      self.currentSegment = ids
      mask.nn.DataSetSegmentSampling_extract(tags, segment1,
                                             x_start, x_end,
                                             y_start, y_end, self.currentIndex, self.currentSegment,
                                             filter.ratio, filter.size, filter.step, step)
   end
   return tags
end
--------------------------------------------------------------------------------
-- Builds self.tags (per-class patch candidates) by loading every raw sample
-- and running parseMask on it, then prints the nb of patches per class.
-- When <path>/<cacheFile>-tags exists it is deserialized instead; otherwise,
-- if a cache file name is known (cacheFile or the saveFile argument, which
-- overrides it), the computed tags are written to it.
-- @param saveFile optional cache file name to store the tags under
--------------------------------------------------------------------------------
function DataSetSamplingPascal:parseAllMasks(saveFile)
   -- evaluated lazily: self.cacheFile may be (re)assigned further down
   local function tagsPath()
      return paths.concat(self.path,self.cacheFile..'-tags')
   end

   -- fast path: reuse an existing cache file
   if self.cacheFile and paths.filep(tagsPath()) then
      print('<DataSetSamplingPascal> retrieving saved tags from :' .. tagsPath()
            .. ' [delete file to force new scan]')
      local reader = torch.DiskFile(tagsPath(), 'r')
      reader:binary()
      self.tags = reader:readObject()
      reader:close()
      return
   end

   -- parse tags, long operation
   print('<DataSetSamplingPascal> parsing all masks to generate list of tags')
   local megabytes = math.ceil(self.nbRawSamples*self.rawSampleMaxSize*self.rawSampleMaxSize*
                               4*2/1024/1024)
   print('<DataSetSamplingPascal> WARNING: this operation could allocate up to '..
         megabytes..'MB')
   self.tags = nil
   for idx = 1,self.nbRawSamples do
      xlua.progress(idx,self.nbRawSamples)
      self:loadSample(idx)
      self.tags = self:parseMask(self.tags)
   end

   -- report: each patch occupies 4 slots of a tag list
   print('<DataSetSamplingPascal> nb of patches extracted per category:')
   for class = 1,self.nbClasses do
      local entry = self.tags[class]
      print(' ' .. class .. ' - ' .. entry.size / 4)
   end

   if saveFile then
      self.cacheFile = saveFile
   end
   -- serialize the tags when a cache file name is known
   if self.cacheFile then
      print('<DataSetSamplingPascal> saving tags to cache file: ' .. tagsPath())
      local writer = torch.DiskFile(tagsPath(), 'w')
      writer:binary()
      writer:writeObject(self.tags)
      writer:close()
   end
end
--------------------------------------------------------------------------------
-- Displays the first raw samples with their colorized masks added on top.
-- Named args: title (window legend), samples (how many to show, capped to
-- nbRawSamples), zoom (display zoom factor). Requires the imgraph package.
--------------------------------------------------------------------------------
function DataSetSamplingPascal:display(...)
   -- check args
   local _, title, samples, zoom = xlua.unpack(
      {...},
      'DataSetSamplingPascal.display',
      'display masks, overlayed on dataset images',
      {arg='title', type='string', help='window title', default='DataSetSamplingPascal'},
      {arg='samples', type='number', help='number of samples to display', default=50},
      {arg='zoom', type='number', help='zoom', default=0.5}
   )
   -- require imgraph package to handle segmentation colors
   require 'imgraph'

   -- load the samples and display them (never past the end of the dataset)
   local nb = math.min(samples, self.nbRawSamples)
   local allimgs = {}
   for i = 1,nb do
      self:loadSample(i)
      local dispTensor = self.currentSample:clone()
      local dispMask = self.currentMask:clone()
      -- keep at most 3 channels
      if dispTensor:size(1) > 3 and dispTensor:nDimension() == 3 then
         dispTensor = dispTensor:narrow(1,1,3)
      end
      -- normalize by the max value; an all-zero image is left as is
      local peak = dispTensor:max()
      if peak ~= 0 then
         dispTensor:div(peak)
      end
      -- self.colormap is reused across calls so colors stay consistent
      dispMask, self.colormap = imgraph.colorize(dispMask, self.colormap)
      dispTensor:add(dispMask)
      allimgs[i] = dispTensor
   end

   -- display
   -- NOTE(review): `painter` is a global (nil unless defined by the caller)
   image.display{win=painter, image=allimgs, legend=title, zoom=zoom}
end