8
8
9
9
MAA_VISION_NS_BEGIN
10
10
11
- std::pair<NeuralNetworkClassifier::ResultsVec, size_t > NeuralNetworkClassifier::analyze () const
11
+ NeuralNetworkClassifier::NeuralNetworkClassifier (
12
+ cv::Mat image,
13
+ NeuralNetworkClassifierParam param,
14
+ std::shared_ptr<Ort::Session> session,
15
+ std::string name)
16
+ : VisionBase(std::move(image), std::move(name))
17
+ , param_(std::move(param))
18
+ , session_(std::move(session))
19
+ {
20
+ analyze ();
21
+ }
22
+
23
+ void NeuralNetworkClassifier::analyze ()
12
24
{
13
25
LogFunc << name_;
14
26
15
27
if (!session_) {
16
28
LogError << " OrtSession not loaded" ;
17
- return {} ;
29
+ return ;
18
30
}
19
31
if (param_.cls_size == 0 ) {
20
32
LogError << " cls_size == 0" ;
21
- return {} ;
33
+ return ;
22
34
}
23
35
if (param_.cls_size != param_.labels .size ()) {
24
36
LogError << " cls_size != labels.size()" << VAR (param_.cls_size )
25
37
<< VAR (param_.labels .size ());
26
- return {} ;
38
+ return ;
27
39
}
28
40
29
41
auto start_time = std::chrono::steady_clock::now ();
30
- ResultsVec results = foreach_rois ();
31
- auto cost = duration_since (start_time);
32
- LogTrace << name_ << " Raw:" << VAR (results) << VAR (cost);
33
42
34
- const auto & expected = param_. expected ;
35
- filter ( results, expected);
43
+ auto results = classify_all_rois () ;
44
+ add_results ( std::move ( results), param_. expected );
36
45
37
- cost = duration_since (start_time);
38
- LogTrace << name_ << " Filter:" << VAR (results) << VAR (expected) << VAR (cost);
46
+ sort ();
39
47
40
- sort (results);
41
- size_t index = preferred_index (results);
42
- return { results, index };
48
+ auto cost = duration_since (start_time);
49
+ LogTrace << name_ << VAR (all_results_) << VAR (filtered_results_) << VAR (cost);
43
50
}
44
51
45
- NeuralNetworkClassifier::ResultsVec NeuralNetworkClassifier::foreach_rois () const
52
+ NeuralNetworkClassifier::ResultsVec NeuralNetworkClassifier::classify_all_rois ()
46
53
{
47
54
if (param_.roi .empty ()) {
48
55
return { classify (cv::Rect (0 , 0 , image_.cols , image_.rows )) };
49
56
}
50
-
51
- ResultsVec results;
52
- for (const cv::Rect& roi : param_.roi ) {
53
- Result res = classify (roi);
54
- results.emplace_back (std::move (res));
57
+ else {
58
+ ResultsVec results;
59
+ for (const cv::Rect& roi : param_.roi ) {
60
+ Result res = classify (roi);
61
+ results.emplace_back (std::move (res));
62
+ }
63
+ return results;
55
64
}
56
-
57
- return results;
58
65
}
59
66
60
- NeuralNetworkClassifier::Result NeuralNetworkClassifier::classify (const cv::Rect& roi) const
67
+ NeuralNetworkClassifier::Result NeuralNetworkClassifier::classify (const cv::Rect& roi)
61
68
{
62
69
if (!session_) {
63
70
LogError << " OrtSession not loaded" ;
@@ -114,17 +121,33 @@ NeuralNetworkClassifier::Result NeuralNetworkClassifier::classify(const cv::Rect
114
121
result.label = param_.labels [result.cls_index ];
115
122
result.box = roi;
116
123
117
- draw_result (result);
124
+ if (debug_draw_) {
125
+ auto draw = draw_result (result);
126
+ handle_draw (draw);
127
+ }
118
128
119
129
return result;
120
130
}
121
131
122
- void NeuralNetworkClassifier::draw_result ( const Result& res) const
132
+ void NeuralNetworkClassifier::add_results (ResultsVec results, const std::vector< size_t >& expected)
123
133
{
124
- if (!debug_draw_) {
125
- return ;
126
- }
134
+ std::ranges::copy_if (results, std::back_inserter (filtered_results_), [&](const auto & res) {
135
+ return std::ranges::find (expected, res.cls_index ) != expected.end ();
136
+ });
137
+
138
+ merge_vector_ (all_results_, std::move (results));
139
+ }
140
+
141
+ void NeuralNetworkClassifier::sort ()
142
+ {
143
+ sort_ (all_results_);
144
+ sort_ (filtered_results_);
145
+
146
+ handle_index (filtered_results_.size (), param_.result_index );
147
+ }
127
148
149
+ cv::Mat NeuralNetworkClassifier::draw_result (const Result& res) const
150
+ {
128
151
cv::Mat image_draw = draw_roi (res.box );
129
152
cv::Point pt (res.box .x + res.box .width + 5 , res.box .y + 20 );
130
153
@@ -140,21 +163,10 @@ void NeuralNetworkClassifier::draw_result(const Result& res) const
140
163
pt.y += 20 ;
141
164
}
142
165
143
- handle_draw (image_draw);
144
- }
145
-
146
- void NeuralNetworkClassifier::filter (ResultsVec& results, const std::vector<size_t >& expected) const
147
- {
148
- if (expected.empty ()) {
149
- return ;
150
- }
151
-
152
- std::erase_if (results, [&](const Result& res) {
153
- return std::find (expected.begin (), expected.end (), res.cls_index ) == expected.end ();
154
- });
166
+ return image_draw;
155
167
}
156
168
157
- void NeuralNetworkClassifier::sort (ResultsVec& results) const
169
+ void NeuralNetworkClassifier::sort_ (ResultsVec& results) const
158
170
{
159
171
switch (param_.order_by ) {
160
172
case ResultOrderBy::Horizontal:
@@ -178,14 +190,4 @@ void NeuralNetworkClassifier::sort(ResultsVec& results) const
178
190
}
179
191
}
180
192
181
- size_t NeuralNetworkClassifier::preferred_index (const ResultsVec& results) const
182
- {
183
- auto index_opt = pythonic_index (results.size (), param_.result_index );
184
- if (!index_opt) {
185
- return SIZE_MAX;
186
- }
187
-
188
- return *index_opt;
189
- }
190
-
191
- MAA_VISION_NS_END
193
+ MAA_VISION_NS_END
0 commit comments