This repository has been archived by the owner on Jul 12, 2020. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy path provider_yolo_1x1.lua
93 lines (69 loc) · 2.54 KB
/
provider_yolo_1x1.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
--
-- Created by IntelliJ IDEA.
-- User: changqi
-- Date: 3/14/16
-- Time: 10:05 AM
--
require 'nn';
require 'image';
require 'xlua';
require 'math';
local class = require 'class'
require 'mattorch'
-- Intended path prefix for the .mat data files (currently empty).
local matDataPrefix = ''
-- Global class table built with the 'class' package; calling Provider()
-- constructs an instance via Provider:__init.
Provider = class('Provider')
-- Load a .mat file with mattorch and return the variable `varName` from it.
-- Fails immediately with a descriptive message (instead of a later
-- "attempt to index a nil value") when the file lacks that variable.
local function loadMatVariable(path, varName)
  local contents = mattorch.load(path)
  local value = contents[varName]
  if value == nil then
    error(string.format("variable '%s' not found in %s", varName, path), 2)
  end
  return value
end

-- Turn a stacked image matrix into an (N, 1, height, width) FloatTensor.
-- After :t() each row holds one flattened height*width image, so the
-- row count is the number of images N.
local function toImageBatch(imagesList, height, width)
  local rows = imagesList:t()
  return rows:reshape(rows:size(1), 1, height, width):float()
end

-- Reverse the order of the four label dimensions (swap 1<->4, then 2<->3)
-- and convert the result to a FloatTensor.
local function toLabelTensor(imagesLabel)
  return imagesLabel:transpose(4, 1):transpose(2, 3):float()
end

--- Load the 1x1 train/test splits from MATLAB .mat files.
-- Reads `imagesList` and `imagesLabel` from four files under
-- matDataPrefix .. 'matlab/results/'. Images are converted to
-- (N, 1, 448, 448) FloatTensors and labels to FloatTensors with their four
-- dimensions reversed. The results are stored in self.trainData and
-- self.testData; testData additionally keeps an unmodified clone of its
-- images in `origData`.
-- @param trainSize ignored; overwritten with the number of loaded training images
-- @param testSize  ignored; overwritten with the number of loaded test images
function Provider:__init(trainSize, testSize)
  local dir = matDataPrefix .. 'matlab/results/'

  print '==> load dataset into trainData/testData'
  local trainImagesRaw = loadMatVariable(dir .. 'img_0721_1x1.mat', 'imagesList')
  local trainLabelsRaw = loadMatVariable(dir .. 'label_0721_1x1.mat', 'imagesLabel')
  print '==> finish load train data, start load test data...'
  local testImagesRaw = loadMatVariable(dir .. 'testimg_0721_1x1.mat', 'imagesList')
  local testLabelsRaw = loadMatVariable(dir .. 'testlabel_0721_1x1.mat', 'imagesLabel')
  print '==> finish load dataset, start clean data...'

  -- Everything here is local (previously globals), so the raw tensors
  -- returned by mattorch.load are no longer reachable after __init returns
  -- and can be garbage-collected.
  local imgWidth, imgHeight = 448, 448
  local trainImages = toImageBatch(trainImagesRaw, imgHeight, imgWidth)
  local trainLabels = toLabelTensor(trainLabelsRaw)
  local testImages = toImageBatch(testImagesRaw, imgHeight, imgWidth)
  local testLabels = toLabelTensor(testLabelsRaw)

  -- The sizes always come from the loaded data; any caller-supplied value
  -- is discarded (kept as parameters for interface compatibility).
  trainSize = trainImages:size(1)
  testSize = testImages:size(1)

  self.trainData = {
    data = trainImages,
    labels = trainLabels,
    size = function() return trainSize end
  }
  self.testData = {
    data = testImages,
    origData = testImages:clone(),
    labels = testLabels,
    size = function() return testSize end
  }
end
--- Standardize the image data of both splits in place.
-- The mean and standard deviation are computed once, from channel 1
-- (dimension 2) of the training images only. Both splits are then shifted
-- by -mean and divided by std, and the statistics are recorded as `.mean`
-- and `.std` on each split.
function Provider:normalize()
  local train, test = self.trainData, self.testData
  collectgarbage()
  print '==> pre-processing data'

  local trainPlane = train.data:select(2, 1)
  local mean, std = trainPlane:mean(), trainPlane:std()

  -- The test split reuses the training statistics.
  for _, split in ipairs({ train, test }) do
    split.data:select(2, 1):add(-mean):div(std)
    split.mean = mean
    split.std = std
  end
end
-- Script entry: build the dataset and normalize it as soon as this file runs.
-- `provider` is deliberately left global (presumably read by a training
-- script that loads this file — confirm before making it local).
provider = Provider()
provider:normalize()
-- Optional caching of the processed dataset (disabled):
--print '==> save provider.t7'
--torch.save('provider_yolo.t7', provider);