// main.cpp (forked from SenderOK/RoadSigns)
#include <string>
#include <vector>
#include <fstream>
#include <cassert>
#include <iostream>
#include <cmath>
#include <algorithm>    // std::max (used in NormLINF)
#include "classifier.h"
#include "EasyBMP/EasyBMP.h"
#include "liblinear-1.93/linear.h"
#include "argvparser/argvparser.h"
#include "io.h"
#include "filters.h"
using std::string;
using std::vector;
using std::ifstream;
using std::ofstream;
using std::pair;
using std::make_pair;
using std::cout;
using std::cerr;
using std::endl;
using CommandLineProcessing::ArgvParser;
//_______________________________________________________________________________________
// PARAMETERS ARE DESCRIBED HERE
//_______________________________________________________________________________________
const int N_DIRECTIONS = 13;   // number of orientation bins in a gradient histogram
const double OVERLAP = 0.7;    // fraction by which neighbouring metablocks overlap
const int BLOCK_SIZE = 4;      // side of one cell (block), in pixels
const int N_BLOCKS = 3;        // cells per metablock side
// Stride between metablock origins; 12 * (1 - 0.7) = 3.6 truncates to 3 pixels
const int METABLOCK_STRIDE = BLOCK_SIZE * N_BLOCKS * (1 - OVERLAP);
// Sobel kernels for horizontal and vertical derivatives
const Matrix<double> filter_x = Matrix<double>({ {-1, 0, 1}, {-2, 0, 2}, {-1, 0, 1} });
const Matrix<double> filter_y = Matrix<double>({ {1, 2, 1}, {0, 0, 0}, {-1, -2, -1} });
const int SVM_SOLVER = L2R_L2LOSS_SVC_DUAL;
const double SVM_C = 0.2;
// Available descriptor normalization schemes (see the Norm* functions below)
enum NORM {
    L1,
    L2,
    LINF,
    L2Hys
};
const int CURR_NORM = L2Hys;
// Approximate chi-square kernel embedding (off by default)
const bool KERNEL_IS_ON = false;
const float KERNEL_L = 0.25;   // spectrum sampling step lambda
const int KERNEL_N = 1;        // samples taken at KERNEL_L * i, i in [-KERNEL_N, KERNEL_N]
//________________________________________________________________________________________
//________________________________________________________________________________________
const double PI = 3.14159265358979323;
typedef vector<pair<string, int> > TFileList;
typedef vector<pair<vector<float>, int> > TFeatures;
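// Overview: each image is converted to grayscale, a HOG-style descriptor is
// built from histograms of oriented gradients over overlapping metablocks,
// each metablock descriptor is normalized (and optionally passed through an
// approximate chi-square kernel embedding), and the concatenated vectors are
// fed to a liblinear SVM.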
// Load the list of files and their labels from 'data_file' and store it in
// 'file_list'. Each line holds a "relative_image_path label" pair; paths are
// resolved relative to the directory containing 'data_file'.
void LoadFileList(const string& data_file, TFileList* file_list) {
    ifstream stream(data_file.c_str());
    string filename;
    int label;

    int char_idx = data_file.size() - 1;
    for (; char_idx >= 0; --char_idx)
        if (data_file[char_idx] == '/' || data_file[char_idx] == '\\')
            break;
    string data_path = data_file.substr(0, char_idx + 1);

    // Read "filename label" pairs until extraction fails or EOF is reached
    while (stream >> filename >> label)
        file_list->push_back(make_pair(data_path + filename, label));

    stream.close();
}
// Save result of prediction to file
void SavePredictions(const TFileList& file_list,
                     const TLabels& labels,
                     const string& prediction_file) {
    // Check that the list of files and the list of labels have equal size
    assert(file_list.size() == labels.size());
    // Open 'prediction_file' for writing
    ofstream stream(prediction_file.c_str());
    // Write file names and labels to stream
    for (size_t image_idx = 0; image_idx < file_list.size(); ++image_idx)
        stream << file_list[image_idx].first << " " << labels[image_idx] << endl;
    stream.close();
}
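// Orientation binning: gradient directions returned by CalcGradDirFilter are
// assumed to lie in [-PI, PI]; the expression below maps that range linearly
// onto [0, N_DIRECTIONS], and each pixel votes into its bin with the gradient
// magnitude as the weight:
//   sector = floor(((dir / PI + 1) / 2) * N_DIRECTIONS)
// The boundary value dir == PI would yield sector == N_DIRECTIONS, which is
// clamped to the last bin.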
vector<float> get_histogram(const Matrix<double> &image)
{
    Matrix<double> sobel_x = image.unary_map(CustomFilter(filter_x));
    Matrix<double> sobel_y = image.unary_map(CustomFilter(filter_y));
    Matrix<double> grad_vals = binary_map(CalcGradValFilter(), sobel_y, sobel_x);
    Matrix<double> grad_dirs = binary_map(CalcGradDirFilter(), sobel_y, sobel_x);

    vector<float> histogram(N_DIRECTIONS, 0);
    for (uint i = 0; i < image.n_rows; ++i) {
        for (uint j = 0; j < image.n_cols; ++j) {
            uint sector_num = uint(((grad_dirs(i, j) / PI + 1) / 2) * N_DIRECTIONS);
            histogram[(sector_num == N_DIRECTIONS) ? N_DIRECTIONS - 1 : sector_num] += grad_vals(i, j);
        }
    }
    return histogram;
}
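// Block-descriptor normalization. EPS keeps the denominators away from zero
// on empty (all-zero) histograms. L2Hys is L2 normalization followed by
// clipping values at 0.2 and renormalizing, as in the original HOG descriptor.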
const float EPS = 0.0001;

void NormL1(vector<float> &v)
{
    float sum = 0;
    for (auto it = v.begin(); it < v.end(); ++it)
        sum += *it;
    sum += EPS;
    for (auto it = v.begin(); it < v.end(); ++it)
        *it /= sum;
}

void NormL2(vector<float> &v)
{
    float sum_sqr = 0;
    for (auto it = v.begin(); it < v.end(); ++it)
        sum_sqr += (*it) * (*it);
    sum_sqr = sqrt(sum_sqr + EPS * EPS);
    for (auto it = v.begin(); it < v.end(); ++it)
        *it /= sum_sqr;
}

void NormL2Hys(vector<float> &v)
{
    float sum_sqr = 0;
    for (auto it = v.begin(); it < v.end(); ++it)
        sum_sqr += (*it) * (*it);
    sum_sqr = sqrt(sum_sqr + EPS * EPS);
    for (auto it = v.begin(); it < v.end(); ++it) {
        *it /= sum_sqr;
        if (*it > 0.2) {
            *it = 0.2;
        }
    }
    NormL2(v);
}

void NormLINF(vector<float> &v)
{
    float max_element = 0;
    for (auto it = v.begin(); it < v.end(); ++it)
        max_element = std::max(max_element, *it);
    if (max_element < EPS)
        return;
    for (auto it = v.begin(); it < v.end(); ++it)
        *it /= max_element;
}
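// Approximate explicit feature map for the chi-square kernel (in the spirit of
// Vedaldi & Zisserman's homogeneous kernel maps): for a non-negative feature x,
// the map is sampled at lambda = KERNEL_L * i, i in [-KERNEL_N, KERNEL_N], as
//   Psi_lambda(x) = sqrt(x / cosh(PI * lambda)) * exp(-i * lambda * log(x)),
// and the real and imaginary parts are stored as consecutive vector entries.
// Inputs near zero are mapped to the zero vector to avoid log(0).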
vector<float> chi_square_kernel(float x)
{
    vector<float> result;
    result.reserve(2 * (2 * KERNEL_N + 1));
    if (x < EPS) {
        for (int i = -KERNEL_N; i <= KERNEL_N; ++i) {
            result.push_back(0);
            result.push_back(0);
        }
    } else {
        for (int i = -KERNEL_N; i <= KERNEL_N; ++i) {
            float lambda = KERNEL_L * i;
            float coeff = sqrt(x / cosh(PI * lambda));
            float tmp = lambda * log(x);
            result.push_back(cos(tmp) * coeff);
            result.push_back(-sin(tmp) * coeff);
        }
    }
    return result;
}
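// Descriptor layout: a window ("metablock") of N_BLOCKS x N_BLOCKS cells of
// BLOCK_SIZE x BLOCK_SIZE pixels slides over the image with step
// METABLOCK_STRIDE; each cell contributes an N_DIRECTIONS-bin histogram
// (expanded 2 * (2 * KERNEL_N + 1) times when the kernel map is on), the
// metablock vector is normalized with the scheme selected by CURR_NORM, and
// all metablock vectors are concatenated into the image feature vector.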
// Extract features from dataset.
void ExtractFeatures(const TFileList& file_list, TFeatures& features) {
    typedef void (*norm_ptr) (vector<float> &);
    // Order must match the NORM enum: L1, L2, LINF, L2Hys
    norm_ptr functions[] =
    {
        NormL1,
        NormL2,
        NormLINF,
        NormL2Hys,
    };
    const norm_ptr norm_f = functions[CURR_NORM];

    features.reserve(file_list.size());
    for (size_t image_idx = 0; image_idx < file_list.size(); ++image_idx) {
        PreciseImage im = load_image<double>(file_list[image_idx].first.c_str());
        // Convert to grayscale with the standard luminance weights
        Matrix<double> image = im.unary_map(WeightPixelFilter(0.299, 0.587, 0.114));

        vector<float> one_image_features;
        for (uint i = 0; i + N_BLOCKS * BLOCK_SIZE < image.n_rows; i += METABLOCK_STRIDE) {
            for (uint j = 0; j + N_BLOCKS * BLOCK_SIZE < image.n_cols; j += METABLOCK_STRIDE) {
                vector<float> one_metablock_features;
                one_metablock_features.reserve(N_BLOCKS * N_BLOCKS * N_DIRECTIONS * (KERNEL_IS_ON ? 2 * (2 * KERNEL_N + 1) : 1));
                for (uint row = 0; row < N_BLOCKS; ++row) {
                    for (uint col = 0; col < N_BLOCKS; ++col) {
                        Matrix<double> part = image.submatrix(i + row * BLOCK_SIZE, j + col * BLOCK_SIZE, BLOCK_SIZE, BLOCK_SIZE);
                        vector<float> one_block_hist;
                        if (KERNEL_IS_ON) {
                            vector<float> hist = get_histogram(part);
                            one_block_hist.reserve(2 * (2 * KERNEL_N + 1) * hist.size());
                            for (size_t k = 0; k < hist.size(); ++k) {
                                vector<float> tmp = chi_square_kernel(hist[k]);
                                one_block_hist.insert(one_block_hist.end(), tmp.begin(), tmp.end());
                            }
                        } else {
                            one_block_hist = get_histogram(part);
                        }
                        one_metablock_features.insert(one_metablock_features.end(), one_block_hist.begin(), one_block_hist.end());
                    }
                }
                norm_f(one_metablock_features);
                one_image_features.insert(one_image_features.end(), one_metablock_features.begin(), one_metablock_features.end());
            }
        }
        features.push_back(make_pair(one_image_features, file_list[image_idx].second));
    }
}
// Train SVM classifier using data from 'data_file' and save the trained model
// to 'model_file'
void TrainClassifier(const string& data_file, const string& model_file) {
    // List of image file names and their labels
    TFileList file_list;
    // Structure of images and their labels
    TFeatures features;
    // Model to be trained
    TModel model;
    // Parameters of classifier
    TClassifierParams params;

    // Load list of image file names and their labels
    LoadFileList(data_file, &file_list);
    // Extract features from images
    ExtractFeatures(file_list, features);

    params.C = SVM_C;
    params.solver_type = SVM_SOLVER;
    TClassifier classifier(params);
    // Train classifier
    classifier.Train(features, &model);
    // Save model to file
    model.Save(model_file);
}
// Predict data from 'data_file' using model from 'model_file' and
// save predictions to 'prediction_file'
void PredictData(const string& data_file,
                 const string& model_file,
                 const string& prediction_file) {
    // List of image file names and their labels
    TFileList file_list;
    // Structure of images and their labels
    TFeatures features;
    // List of image labels
    TLabels labels;

    // Load list of image file names and their labels
    LoadFileList(data_file, &file_list);
    // Extract features from images
    ExtractFeatures(file_list, features);

    // Classifier
    TClassifier classifier = TClassifier(TClassifierParams());
    // Trained model
    TModel model;
    // Load model from file
    model.Load(model_file);
    // Predict images by their features using 'model' and store predictions
    // in 'labels'
    classifier.Predict(features, model, &labels);
    // Save predictions
    SavePredictions(file_list, labels, prediction_file);
}
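// Example invocations (binary and file names are illustrative):
//   ./task2 --train   -d data/train/labels.txt -m model.txt
//   ./task2 --predict -d data/test/labels.txt  -m model.txt -l predictions.txt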
int main(int argc, char** argv) {
    // Command line options parser
    ArgvParser cmd;
    // Description of program
    cmd.setIntroductoryDescription("Machine graphics course, task 2. CMC MSU, 2013.");
    // Add help option
    cmd.setHelpOption("h", "help", "Print this help message");
    // Add other options
    cmd.defineOption("data_set", "File with dataset",
        ArgvParser::OptionRequiresValue | ArgvParser::OptionRequired);
    cmd.defineOption("model", "Path to file to save or load model",
        ArgvParser::OptionRequiresValue | ArgvParser::OptionRequired);
    cmd.defineOption("predicted_labels", "Path to file to save prediction results",
        ArgvParser::OptionRequiresValue);
    cmd.defineOption("train", "Train classifier");
    cmd.defineOption("predict", "Predict dataset");

    // Add option aliases
    cmd.defineOptionAlternative("data_set", "d");
    cmd.defineOptionAlternative("model", "m");
    cmd.defineOptionAlternative("predicted_labels", "l");
    cmd.defineOptionAlternative("train", "t");
    cmd.defineOptionAlternative("predict", "p");

    // Parse options
    int result = cmd.parse(argc, argv);

    // Check for errors or help option
    if (result) {
        cout << cmd.parseErrorDescription(result) << endl;
        return result;
    }

    // Get values
    string data_file = cmd.optionValue("data_set");
    string model_file = cmd.optionValue("model");
    bool train = cmd.foundOption("train");
    bool predict = cmd.foundOption("predict");

    // If we need to train the classifier
    if (train) {
        TrainClassifier(data_file, model_file);
    }

    // If we need to predict data
    if (predict) {
        // A file for saving the predictions must be specified
        if (!cmd.foundOption("predicted_labels")) {
            cerr << "Error! Option --predicted_labels not found!" << endl;
            return 1;
        }
        // File to save predictions
        string prediction_file = cmd.optionValue("predicted_labels");
        // Predict data
        PredictData(data_file, model_file, prediction_file);
    }

    return 0;
}