-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataloader.lua
103 lines (90 loc) · 2.82 KB
/
dataloader.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
-- dataloader.lua
-- Multi-threaded mini-batch loader: a pool of worker threads (torch `threads`
-- package) fetches and preprocesses samples from a dataset and assembles them
-- into batches that the main thread consumes via DataLoader:run().
local datasets = require 'datasets/init'
local Threads = require 'threads'
-- Select how values crossing thread boundaries (job functions, arguments,
-- results) are serialized. NOTE(review): 'sharedserialize' presumably shares
-- tensor storage between threads rather than copying it — confirm in the
-- threads package docs.
Threads.serialization('threads.sharedserialize')
local M = {}
-- Declare the Torch class 'resnet.DataLoader'; torch.class registers its
-- constructor on M, which is why instances are created with M.DataLoader(...).
local DataLoader = torch.class('resnet.DataLoader', M)
--- Build one loader per split: first 'train', then 'val'.
-- For each split the dataset is created with `datasets.create(opt, split)`
-- and wrapped in a new DataLoader.
-- @param opt options table, forwarded unchanged to `datasets.create` and to
--   the DataLoader constructor
-- @return the train loader and the val loader, as two return values
function DataLoader.create(opt)
   local splitNames = {'train', 'val'}
   local result = {}
   for k = 1, #splitNames do
      local splitName = splitNames[k]
      local splitDataset = datasets.create(opt, splitName)
      result[#result + 1] = M.DataLoader(splitDataset, opt, splitName)
   end
   return table.unpack(result)
end
--- Construct a loader for one split and spawn its worker thread pool.
-- Each of the `opt.nThreads` workers requires `datasets/<opt.dataset>`, limits
-- torch to one intra-op thread, and installs the dataset and its preprocessing
-- function as globals (`_G.dataset`, `_G.preprocess`) for jobs submitted by
-- `run`.
-- @param dataset dataset object for this split; must provide `:preprocess()`,
--   `:size()` and `:get(i)`
-- @param opt options table; reads `opt.dataset`, `opt.nThreads`,
--   `opt.tenCrop` and `opt.batchSize`
-- @param split split name; ten-crop evaluation only applies when it is 'val'
-- Raises an error (via assert) if `opt.batchSize` is smaller than the number
-- of crops per sample, because a zero per-loader batch size would make `run`
-- loop forever and `size` return infinity.
function DataLoader:__init(dataset, opt, split)
   -- Compute and validate the batch geometry before spawning any threads, so
   -- an invalid configuration fails fast without creating a worker pool.
   local nCrops = (split == 'val' and opt.tenCrop) and 10 or 1
   -- opt.batchSize counts crops, so each loader batch holds
   -- batchSize / nCrops distinct samples.
   local batchSize = math.floor(opt.batchSize / nCrops)
   assert(batchSize >= 1,
      'DataLoader: opt.batchSize (' .. tostring(opt.batchSize)
      .. ') must be at least the number of crops (' .. nCrops
      .. ') for split ' .. tostring(split))

   local function init()
      require('datasets/' .. opt.dataset)
   end
   local function main(_)
      torch.setnumthreads(1)
      _G.dataset = dataset
      _G.preprocess = dataset:preprocess()
      return dataset:size()
   end
   local threads, sizes = Threads(opt.nThreads, init, main)

   self.nCrops = nCrops
   self.threads = threads
   -- Every worker's `main` returns dataset:size() for the same dataset, so the
   -- first worker's result is used.
   self.__size = sizes[1][1]
   self.batchSize = batchSize
end
--- Number of mini-batches in one pass over the split.
-- The final batch may be smaller than `self.batchSize`, hence the ceiling.
function DataLoader:size()
   local nSamples = self.__size
   local perBatch = self.batchSize
   local nBatches = math.ceil(nSamples / perBatch)
   return nBatches
end
--- Start one pass over the split and return an iterator over its mini-batches.
-- Samples are visited in a fresh random permutation on each call. The returned
-- function yields `n, sample`, where `n` counts delivered batches from 1 and
-- `sample` is the table built by a worker job:
--   sample.input  : FloatTensor viewed as (sz * nCrops, unpack(imageSize))
--   sample.target : IntTensor of length sz
-- It returns nil once no jobs remain, so it can drive a generic `for` loop:
--   for n, sample in loader:run() do ... end
function DataLoader:run()
local threads = self.threads
local size, batchSize = self.__size, self.batchSize
-- New random order of dataset indices (1..size) for this pass.
local perm = torch.randperm(size)
-- idx: position in perm of the next sample not yet submitted.
-- sample: latest batch, written by the job's end callback below.
local idx, sample = 1, nil
-- Submit batch-sized slices of perm as worker jobs while unsubmitted samples
-- remain and the thread pool accepts more jobs.
local function enqueue()
while idx <= size and threads:acceptsjob() do
-- The last slice may be shorter than batchSize.
local indices = perm:narrow(1, idx, math.min(batchSize, size - idx + 1))
threads:addjob(
-- Job function (runs in a worker thread). Every name it uses is a
-- parameter, a local, or a global (`_G.dataset` / `_G.preprocess` are
-- installed per worker in __init), so it captures nothing from run().
function(indices, nCrops)
local sz = indices:size(1)
local batch, imageSize
local target = torch.IntTensor(sz)
-- `idx` here is a dataset index; it shadows the outer counter.
for i, idx in ipairs(indices:totable()) do
-- `local` shadows run()'s `sample`, keeping this function upvalue-free.
local sample = _G.dataset:get(idx)
local input = _G.preprocess(sample.input)
-- Allocate the batch lazily, sized from the first preprocessed input.
if not batch then
imageSize = input:size():totable()
-- With several crops, drop input's leading dimension, presumably the
-- crop dimension produced by preprocess (TODO confirm), since nCrops
-- becomes its own batch dimension below.
if nCrops > 1 then table.remove(imageSize, 1) end
batch = torch.FloatTensor(sz, nCrops, table.unpack(imageSize))
end
batch[i]:copy(input)
target[i] = sample.target
end
-- Full GC cycle in the worker before handing back the batch.
collectgarbage()
-- Fold the crop dimension into the batch dimension.
return {
input = batch:view(sz * nCrops, table.unpack(imageSize)),
target = target,
}
end,
-- End callback: receives the job's return value and stores it in
-- run()'s `sample` for loop() to return.
function(_sample_)
sample = _sample_
end,
indices,
self.nCrops
)
idx = idx + batchSize
end
end
-- Number of batches delivered so far in this pass.
local n = 0
-- Iterator: returns (n, sample) for the next finished batch, or nil when done.
local function loop()
-- Top up the pool before waiting.
enqueue()
-- Nothing queued or running: the pass is complete.
if not threads:hasjob() then
return nil
end
-- Wait for a job to finish and run its end callback (which sets `sample`).
threads:dojob()
-- NOTE(review): synchronize() presumably re-raises the worker's error on
-- the main thread — confirm in the threads package docs.
if threads:haserror() then
threads:synchronize()
end
-- Refill the pool so workers stay busy while the caller uses this batch.
enqueue()
n = n + 1
return n, sample
end
return loop
end
-- Module value: the DataLoader constructor that torch.class registered on M.
return M.DataLoader